{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Keras Model Preparer\n",
    "\n",
    "This notebook shows how to prepare a Keras model for quantization. Specifically, this preparer converts a Keras model with subclass layers into a Keras model with functional layers. This is required for quantization because the AIMET quantization tooling only supports the Functional and Sequantial Keras model building API's.\n",
    "\n",
    "To learn more about the Keras Model Preparer, please refer to the API Docs in AIMET.\n",
    "\n",
    "#### Overall flow\n",
    "This notebook covers the following\n",
    "1. Creating a Keras model with subclass layers\n",
    "2. Converting the Keras model with subclass layers to a Keras model with functional layers\n",
    "3. Showing similarities and differences between the original and converted models\n",
    "4. Dicussing the limitations of the Keras Model Preparer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 1. Creating a Keras model with subclass layers\n",
    "\n",
    "First, we will create a Keras model with subclass layers. For this notebook example, we will use a model defined by Keras that utilizes subclass layers. This model is a text classification transformer model and can be found [here]( https://keras.io/examples/nlp/text_classification_with_transformer/). The subclass layers used in this model are - `TokenAndPositionEmbedding` and `TransformerBlock`. They are defined below. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "class TransformerBlock(tf.keras.layers.Layer):\n",
    "    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
    "        super(TransformerBlock, self).__init__()\n",
    "        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n",
    "        self.ffn = tf.keras.Sequential(\n",
    "            [tf.keras.layers.Dense(ff_dim, activation=\"relu\"), tf.keras.layers.Dense(embed_dim),]\n",
    "        )\n",
    "        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.dropout1 = tf.keras.layers.Dropout(rate)\n",
    "        self.dropout2 = tf.keras.layers.Dropout(rate)\n",
    "\n",
    "    def call(self, inputs, training, **kwargs):\n",
    "        attn_output = self.att(inputs, inputs)\n",
    "        attn_output = self.dropout1(attn_output, training=training)\n",
    "        out1 = self.layernorm1(inputs + attn_output)\n",
    "        ffn_output = self.ffn(out1)\n",
    "        ffn_output = self.dropout2(ffn_output, training=training)\n",
    "        return self.layernorm2(out1 + ffn_output)\n",
    "\n",
    "\n",
    "\n",
    "class TokenAndPositionEmbedding(tf.keras.layers.Layer):\n",
    "    def __init__(self, maxlen, vocab_size, embed_dim):\n",
    "        super(TokenAndPositionEmbedding, self).__init__()\n",
    "        self.token_emb = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)\n",
    "        self.pos_emb = tf.keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n",
    "\n",
    "    def call(self, x, **kwargs):\n",
    "        maxlen = tf.shape(x)[-1]\n",
    "        positions = tf.range(start=0, limit=maxlen, delta=1)\n",
    "        positions = self.pos_emb(positions)\n",
    "        x = self.token_emb(x)\n",
    "        x = x + positions\n",
    "        return x"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With those subclass layers defined, we can now define the model. Since we are not training the model, we will use random weights and a random input tensor to build the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "vocab_size = 20000 \n",
    "maxlen = 200\n",
    "\n",
    "random_input = np.random.random((10, 200)) # Random input to build the model\n",
    "\n",
    "embed_dim = 32  # Embedding size for each token\n",
    "num_heads = 2  # Number of attention heads\n",
    "ff_dim = 32  # Hidden layer size in feed forward network inside transformer\n",
    "\n",
    "inputs = tf.keras.layers.Input(shape=(maxlen,))\n",
    "embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)\n",
    "x = embedding_layer(inputs)\n",
    "transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)\n",
    "x = transformer_block(x)\n",
    "x = tf.keras.layers.GlobalAveragePooling1D()(x)\n",
    "x = tf.keras.layers.Dropout(0.1)(x)\n",
    "x = tf.keras.layers.Dense(20, activation=\"relu\")(x)\n",
    "x = tf.keras.layers.Dropout(0.1)(x)\n",
    "outputs = tf.keras.layers.Dense(2, activation=\"softmax\")(x)\n",
    "\n",
    "model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
    "_ = model(random_input)\n",
    "model.summary()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the `model.summary()` output, we can see the models 2 subclass layers - `token_and_position_embedding`, `transformer_block`. Since these layers are using layer inside they're classes, we need to extract them to create a symmetrical functional model. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 2. Converting the Keras model with subclass layers to a Keras model with functional layers\n",
    "\n",
    "The Keras Model Preparer can be used to convert a Keras model with subclass layers to a Keras model with functional layers. The Keras Model Preparer can be imported from `aimet_tensorflow.keras.model_preparer`. The Keras Model Preparer takes in a Keras model with subclass layers and returns a Keras model with functional layers. Note that the `prepare_model` function takes an optional `input_layer` parameter. This parameter is required if the model begins with a subclass layer. In this case, the model does not begin with a subclass layer, so we do not need to provide an `input_shape` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_tensorflow.keras.model_preparer import prepare_model\n",
    "\n",
    "functional_model = prepare_model(model) \n",
    "functional_model.summary()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the Keras Model Preparer has converted the model with subclass layers to a model with functional layers. Specifically, it has extracted the call function of each of these layers and created a functional layer from it."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 3. Showing similarities and differences between the original and converted models\n",
    "\n",
    "We can see that the original model and the converted model are symmetrical. The only difference is that the subclass layers are unwrapped. This means that the converted model is functionally identical to the original model. We can test this in a few ways.\n",
    "\n",
    "1) We can compare the total number of parameters in the original and converted models. We can see that the total number of parameters is the same.\n",
    "\n",
    "2) We can compare the weights of the original and converted models. We can see that the weights are the same.\n",
    "    * Note that the order of the weights presented when calling `get_weights()` on each of these models are not the same and as is the names of the weights. We can use an internal function to get the original models weights in the same order as the converted models weights.\n",
    "\n",
    "3) We can compare the outputs of the original and converted models. We can see that the outputs are the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Set, List\n",
    "\n",
    "# This function is a functional representation of the reordering done inside the model preparer for Keras.\n",
    "def get_original_models_weights_in_functional_model_order(\n",
    "    original_model: tf.keras.Model,\n",
    "    functional_model: tf.keras.Model,\n",
    "    class_names: Set[str]\n",
    ") -> List[np.ndarray]:\n",
    "    \"\"\"Map the original model's weights to the functional model's weights.\n",
    "\n",
    "    Args:\n",
    "        original_model:\n",
    "            Original model to reference the weight order.\n",
    "        functional_model: \n",
    "            Prepared model to updates weight of.\n",
    "        class_names: \n",
    "            Names of the classes that the original model was subclassed from\n",
    "\n",
    "    Returns:\n",
    "        A list of the original model's weights in the order of the functional model's weights\n",
    "    \"\"\"\n",
    "\n",
    "    # Make the original model's weights into a dictionary for quick lookup by name\n",
    "    # The original subclassed layers names are removed to match the new functional model's names\n",
    "    original_model_weights = {}\n",
    "    for weight in original_model.weights:\n",
    "        # pop out class_names of weight name\n",
    "        weight_name = weight.name\n",
    "        for class_name in class_names:\n",
    "            weight_name = weight_name.replace(class_name + '/', '')\n",
    "        original_model_weights[weight_name] = weight.numpy()\n",
    "\n",
    "    # Get the functional model's weights in order as a dictionary for quick lookup where the key is the weight name\n",
    "    # and the position of the weight's order is the value\n",
    "    functional_model_weight_order = {\n",
    "        weight.name: position\n",
    "        for position, weight in enumerate(functional_model.weights)\n",
    "    }\n",
    "\n",
    "    # Using the functional model's weights order, get the original model's weights in the same order. The lambda here\n",
    "    # uses the weight's name to get position in the functional model's weights order and the sorts the original model's\n",
    "    # weights by that position.\n",
    "    weights_in_correct_order = [\n",
    "        weight for _, weight in\n",
    "        sorted(original_model_weights.items(), key=lambda weight_info: functional_model_weight_order[weight_info[0]])\n",
    "    ]\n",
    "\n",
    "    return weights_in_correct_order\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert functional_model.count_params() == model.count_params()\n",
    "assert functional_model.input_shape == model.input_shape\n",
    "assert functional_model.output_shape == model.output_shape\n",
    "\n",
    "# NOTE: Since TextClassification Model has the internal layers out of order compared to the call method,\n",
    "# the weights are not in the order of what the actual architecture is (this is a Keras design).\n",
    "# Therefore, we get the original model's weights and sort them in the order of the actual\n",
    "# architecture and use those weights to compare to the functional model's weights.\n",
    "model_weights_in_correct_order = get_original_models_weights_in_functional_model_order(\n",
    "    model, functional_model, class_names=[\"token_and_position_embedding\", \"transformer_block\"])\n",
    "\n",
    "for i, _ in enumerate(model_weights_in_correct_order):\n",
    "        np.testing.assert_array_equal(model_weights_in_correct_order[i], functional_model.get_weights()[i])\n",
    "\n",
    "np.testing.assert_array_equal(functional_model(random_input).numpy(), model(random_input).numpy())\n",
    "print(\"Models are equal\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Discussing the limitations of the Keras Model Preparer\n",
    "\n",
    "- The AIMET Keras ModelPreparer API is able to convert subclass layers that have arthmetic experssion in their call function.\n",
    "However, this API and Keras, will convert these operations to TFOPLambda layers which are not currently supported by AIMET Keras Quantization API. \n",
    "If possible, it is recommended to have the subclass layers call function ressemble the Keras Functional API layers.\n",
    "For example, if a subclass layer has two convolution layers in its call function, the call function should look like\n",
    "the following:\n",
    "\n",
    "    ```python\n",
    "    def call(self, x, **kwargs):\n",
    "        x = self.conv_1(x)\n",
    "        x = self.conv_2(x)\n",
    "        return x\n",
    "    ```\n",
    "\n",
    "- If the model starts with a subclassed layer, the AIMET Keras ModelPreparer API will need an Keras Input Layer as input.\n",
    "This is becuase the Keras Functional API requires an Input Layer as the first layer in the model. The AIMET Keras ModelPreparer API\n",
    "will raise an exception if the model starts with a subclassed layer and an Input Layer is not provided as input."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Summary\n",
    "\n",
    "Hopefully this notebook was useful for you to understand how to use the Keras Model Preparer.\n",
    "\n",
    "Few additional resources:\n",
    "- [AIMET API Docs](https://quic.github.io/aimet-pages/releases/latest/user_guide/index.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
